# improved_ant.py
import gym
import numpy as np
from gym import spaces
from gym.envs.registration import register


class ImprovedAntEnv(gym.Env):
    """
    A lightweight, MuJoCo-free Ant-style environment (planar).

    - Torso with 4 legs (front/back × left/right), 2 joints per leg (hip, knee)
    - Action space: torques for 8 joints in [-1, 1]
    - Observation (size = 26):
        [0:2]     torso (x, z)
        [2]       torso pitch (rad)
        [3:11]    joint angles (8)
        [11:13]   torso linear vel (x_dot, z_dot)
        [13]      torso ang vel (pitch_dot)
        [14:22]   joint angular vels (8)
        [22:26]   foot contacts [FL, FR, BL, BR], in {0, 1} before obs noise
    - Noise: action, dynamics, observations (scales configurable). Observation
      noise is added to every entry, including the contact flags.
    - All randomness is drawn from NumPy's *global* RNG (``np.random``).
    - get_state()/set_state() provided for snapshot-based planning
    - Legacy Gym API: reset() -> obs, step() -> (obs, reward, done, info)
    """

    metadata = {"render.modes": ["human"]}

    def __init__(
        self,
        action_noise_scale: float = 0.03,
        dynamics_noise_scale: float = 0.02,
        obs_noise_scale: float = 0.01,
        max_steps: int = 1000,
    ):
        """
        Args:
            action_noise_scale: std of Gaussian noise added to the action before clipping.
            dynamics_noise_scale: std of Gaussian noise added each step to the joint,
                torso linear, and torso angular velocities.
            obs_noise_scale: std of Gaussian noise added to every observation entry.
            max_steps: step count at which ``done`` becomes True.
        """
        super().__init__()

        # --- Action/Observation spaces ---
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32)
        # NOTE(review): torso x is integrated without bounds, so long or fast
        # episodes can produce observations outside these +-50 bounds.
        self.observation_space = spaces.Box(low=-50.0, high=50.0, shape=(26,), dtype=np.float32)

        # --- Noise params ---
        self.action_noise_scale = float(action_noise_scale)
        self.dynamics_noise_scale = float(dynamics_noise_scale)
        self.obs_noise_scale = float(obs_noise_scale)

        # --- Physics params (simple & stable) ---
        self.dt = 0.04
        self.mass = 5.0
        self.gravity = -9.8

        self.torque_scale = 3.0         # action in [-1, 1] -> joint torque
        self.joint_limit = 1.2          # joint angles are kept within [-limit, limit] (rad)
        self.joint_damping = 0.12
        self.joint_coupling = 0.20      # coupling hip↔knee within a leg
        self.ang_momentum_coupling = 0.08
        self.momentum = 0.9

        # Body geometry
        self.torso_width = 0.6
        self.torso_length = 0.8
        self.leg_len = 0.6
        self.hip_offset_x = 0.35        # fore/back hip anchor offset along x
        self.hip_offset_y = 0.5 * self.torso_width  # left/right half-width
        self.ground_z = 0.0

        # Contact / drag
        self.contact_push = 6.0
        self.air_drag = 0.985
        self.ground_drag = 0.65

        # Reward weights
        self.w_fwd = 2.4
        self.alive_bonus = 1.0
        self.h_bonus = 0.35
        self.w_energy = 0.02
        self.w_posture = 0.4
        self.stumble_w = 0.02

        # Height handling and termination thresholds.
        # min_height is the floor that step() clamps the torso z to.
        # fall_height is the height below which the episode terminates.
        # NOTE(review): because the torso is clamped to >= min_height (0.4),
        # a fall_height below that can never trigger termination. Confirm
        # whether height-based termination is intended.
        self.min_height = 0.4
        self.fall_height = 0.25
        self.max_torso_angle = np.deg2rad(70.0)

        # Episode bookkeeping
        self.max_steps = int(max_steps)
        self.steps = 0

        # State buffers (re-initialized by reset())
        self.torso_pos = np.array([0.0, 1.0], dtype=np.float32)    # (x, z)
        self.torso_vel = np.zeros(2, dtype=np.float32)             # (x_dot, z_dot)
        self.torso_ang = 0.0                                       # pitch
        self.torso_ang_vel = 0.0

        # Joint ordering (8): [FL_hip, FL_knee, FR_hip, FR_knee, BL_hip, BL_knee, BR_hip, BR_knee]
        self.q = np.zeros(8, dtype=np.float32)
        self.qd = np.zeros(8, dtype=np.float32)

        # Contacts: [FL, FR, BL, BR]
        self.contact = np.zeros(4, dtype=np.int32)
        self.prev_foot_xy = np.zeros((4, 2), dtype=np.float32)

        self.reset()

    # -----------------------------
    # Helpers
    # -----------------------------
    def _foot_positions(self):
        """
        Approximate feet (x, z) from hip anchors and leg angles.

        Each leg's effective angle is hip + 0.7 * knee + torso pitch. The foot
        lies leg_len away from its hip anchor along that angle, where angle 0
        points straight down.

        Returns:
            float32 array of shape (4, 2): rows [FL, FR, BL, BR], columns (x, z).
        """
        x, z = float(self.torso_pos[0]), float(self.torso_pos[1])
        pitch = float(self.torso_ang)
        off = self.hip_offset_x

        # Front legs (FL, FR) anchor at +hip_offset_x and back legs (BL, BR) at
        # -hip_offset_x. The model is planar, so left/right legs share an
        # anchor. All anchors sit at torso height; pitch does not move them.
        hip_x = np.array([x + off, x + off, x - off, x - off], dtype=np.float32)
        hip_z = np.full(4, z, dtype=np.float32)

        # Joints are laid out [hip, knee] per leg, so even indices are hips.
        eff = (self.q[0::2] + 0.7 * self.q[1::2] + pitch).astype(np.float32)

        fx = hip_x + self.leg_len * np.sin(eff)
        fz = hip_z - self.leg_len * np.cos(eff)
        return np.stack([fx, fz], axis=1).astype(np.float32)

    def _compute_contacts(self, feet):
        """Return int32 flags (shape (4,)): 1 where a foot's z is at or below ground."""
        return (feet[:, 1] <= self.ground_z + 1e-6).astype(np.int32)

    def _get_obs(self):
        """Assemble the 26-dim observation (layout in the class docstring), plus noise."""
        obs = np.zeros(26, dtype=np.float32)
        obs[0:2] = self.torso_pos
        obs[2] = self.torso_ang
        obs[3:11] = self.q
        obs[11:13] = self.torso_vel
        obs[13] = self.torso_ang_vel
        obs[14:22] = self.qd
        obs[22:26] = self.contact

        # Noise is applied to every entry, including the contact flags.
        if self.obs_noise_scale > 0:
            obs += np.random.normal(0.0, self.obs_noise_scale, size=obs.shape).astype(np.float32)
        return obs

    # -----------------------------
    # Snapshot state (pickle-safe)
    # -----------------------------
    def get_state(self):
        """
        Return a snapshot dict of the simulation state.

        The snapshot holds copies of all state arrays, the noise scales, and
        the global ``np.random`` state.
        """
        return {
            "steps": int(self.steps),
            "torso_pos": self.torso_pos.copy(),
            "torso_vel": self.torso_vel.copy(),
            "torso_ang": float(self.torso_ang),
            "torso_ang_vel": float(self.torso_ang_vel),
            "q": self.q.copy(),
            "qd": self.qd.copy(),
            "contact": self.contact.copy(),
            "prev_foot_xy": self.prev_foot_xy.copy(),
            "action_noise_scale": float(self.action_noise_scale),
            "dynamics_noise_scale": float(self.dynamics_noise_scale),
            "obs_noise_scale": float(self.obs_noise_scale),
            "rng_state": np.random.get_state(),
        }

    def set_state(self, state):
        """
        Restore a snapshot produced by get_state().

        Noise-scale keys are optional; if missing, the current values are kept.
        If "rng_state" is present, the *global* ``np.random`` state is
        overwritten. This affects any other code that uses ``np.random``.
        """
        self.steps = int(state["steps"])
        self.torso_pos = np.array(state["torso_pos"], dtype=np.float32, copy=True)
        self.torso_vel = np.array(state["torso_vel"], dtype=np.float32, copy=True)
        self.torso_ang = float(state["torso_ang"])
        self.torso_ang_vel = float(state["torso_ang_vel"])
        self.q = np.array(state["q"], dtype=np.float32, copy=True)
        self.qd = np.array(state["qd"], dtype=np.float32, copy=True)
        self.contact = np.array(state["contact"], dtype=np.int32, copy=True)
        self.prev_foot_xy = np.array(state["prev_foot_xy"], dtype=np.float32, copy=True)
        self.action_noise_scale = float(state.get("action_noise_scale", self.action_noise_scale))
        self.dynamics_noise_scale = float(state.get("dynamics_noise_scale", self.dynamics_noise_scale))
        self.obs_noise_scale = float(state.get("obs_noise_scale", self.obs_noise_scale))
        if "rng_state" in state:
            np.random.set_state(state["rng_state"])

    # -----------------------------
    # Gym API
    # -----------------------------
    def reset(self):
        """
        Start a new episode and return the initial observation.

        The torso starts at (0, 1) at rest. Pitch and joint angles get small
        uniform perturbations in [-0.05, 0.05].
        """
        self.steps = 0

        self.torso_pos = np.array([0.0, 1.0], dtype=np.float32)
        self.torso_vel = np.zeros(2, dtype=np.float32)
        self.torso_ang = float(np.random.uniform(-0.05, 0.05))
        self.torso_ang_vel = 0.0

        self.q = np.random.uniform(-0.05, 0.05, size=8).astype(np.float32)
        self.qd = np.zeros(8, dtype=np.float32)

        feet = self._foot_positions()
        self.contact = self._compute_contacts(feet)
        self.prev_foot_xy = feet.copy()

        return self._get_obs()

    def step(self, action):
        """
        Advance the simulation by one control step of ``dt`` seconds.

        Args:
            action: 8 joint commands, as any array-like that reshapes to (8,).
                Action noise is added first, then values are clipped to [-1, 1].

        Returns:
            (obs, reward, done, info) following the legacy Gym API.
        """
        self.steps += 1

        # --- Action handling (robust to shape; never mutates the caller's array) ---
        a = np.asarray(action, dtype=np.float32).reshape(8)
        if self.action_noise_scale > 0:
            a = a + np.random.normal(0.0, self.action_noise_scale, size=a.shape).astype(np.float32)
        a = np.clip(a, -1.0, 1.0)  # returns a fresh array

        # --- Joint torque with symmetric hip<->knee coupling within each leg ---
        tau = self.torque_scale * a
        tau_c = tau.copy()
        tau_c[0::2] += self.joint_coupling * tau[1::2]  # each hip feels its knee
        tau_c[1::2] += self.joint_coupling * tau[0::2]  # each knee feels its hip

        # Angular velocity update with viscous damping
        self.qd += (tau_c - self.joint_damping * self.qd) * self.dt

        # Optional dynamics noise
        if self.dynamics_noise_scale > 0:
            self.qd += np.random.normal(0.0, self.dynamics_noise_scale, size=self.qd.shape).astype(np.float32)
            self.torso_vel += np.random.normal(0.0, self.dynamics_noise_scale, size=self.torso_vel.shape).astype(np.float32)
            self.torso_ang_vel += float(np.random.normal(0.0, self.dynamics_noise_scale))

        # Integrate joint angles, then enforce the joint limits. A joint pinned
        # at a limit cannot keep moving outward, so drop the outward component
        # of its velocity. Otherwise qd winds up toward tau/damping while q
        # stays fixed, and that phantom speed feeds the |qd|-based forward
        # drive below.
        lim = np.float32(self.joint_limit)
        self.q = np.clip(self.q + self.qd * self.dt, -lim, lim)
        outward = ((self.q >= lim) & (self.qd > 0)) | ((self.q <= -lim) & (self.qd < 0))
        self.qd[outward] = 0.0

        # --- Contact & ground interaction ---
        feet = self._foot_positions()
        self.contact = self._compute_contacts(feet)
        in_contact = self.contact == 1
        n_contact = int(in_contact.sum())

        # Each grounded foot adds a constant upward push, plus a term that
        # opposes downward torso velocity (vertical bounce support). Each
        # grounded leg also converts its joint speed (|hip_dot| + |knee_dot|)
        # into forward push.
        vz = float(self.torso_vel[1])
        support_per_foot = self.contact_push * 0.08
        if vz < 0:
            support_per_foot += -vz * self.mass * 0.4
        push_up = n_contact * support_per_foot
        leg_ang_speed = np.abs(self.qd).reshape(4, 2).sum(axis=1)
        push_forward = self.contact_push * 0.18 * float(leg_ang_speed[in_contact].sum())

        # --- Torso dynamics ---
        # Every forward-drive term is non-negative, so any joint motion
        # propels the torso toward +x, whatever its direction.
        fwd_drive = self.ang_momentum_coupling * float(np.abs(self.qd).sum()) + push_forward
        up_drive = push_up + self.gravity

        self.torso_vel[0] = self.momentum * self.torso_vel[0] + fwd_drive * self.dt
        self.torso_vel[1] = self.momentum * self.torso_vel[1] + up_drive * self.dt

        drag = self.ground_drag if n_contact > 0 else self.air_drag
        self.torso_vel *= drag

        # Integrate position. The torso z is clamped to min_height with a
        # small inelastic bounce instead of falling through.
        self.torso_pos += self.torso_vel * self.dt
        if self.torso_pos[1] < self.min_height:
            self.torso_pos[1] = self.min_height
            if self.torso_vel[1] < 0:
                self.torso_vel[1] = -0.3 * self.torso_vel[1]

        # Passive torso stabilization around pitch=0 (spring-damper)
        balance_torque = -0.15 * self.torso_ang - 0.07 * self.torso_ang_vel
        self.torso_ang_vel += balance_torque * self.dt
        self.torso_ang += self.torso_ang_vel * self.dt
        self.torso_ang = float(np.clip(self.torso_ang, -np.pi, np.pi))

        # --- Reward ---
        forward_vel = float(self.torso_vel[0])
        height = float(self.torso_pos[1])
        energy = float(np.dot(a, a))
        posture_pen = float(self.torso_ang ** 2)

        # Foot slip penalty: horizontal foot displacement since the previous
        # step, summed over feet that were in contact this step.
        new_feet = self._foot_positions()
        slip = np.abs(new_feet[:, 0] - self.prev_foot_xy[:, 0])
        slip_pen = float(slip[in_contact].sum())
        self.prev_foot_xy = new_feet

        reward = (
            self.w_fwd * forward_vel
            + self.alive_bonus
            + self.h_bonus * height
            - self.w_energy * energy
            - self.w_posture * posture_pen
            - self.stumble_w * slip_pen
        )

        done = (
            height < self.fall_height
            or abs(self.torso_ang) > self.max_torso_angle
            or self.steps >= self.max_steps
        )

        info = {}
        return self._get_obs(), float(reward), bool(done), info

    def render(self, mode="human"):
        """Print a one-line text summary of torso state, contacts, and feet (mode "human" only)."""
        if mode == "human":
            feet = self._foot_positions()
            print(
                f"x={self.torso_pos[0]:.2f} z={self.torso_pos[1]:.2f}  "
                f"vx={self.torso_vel[0]:.2f} vz={self.torso_vel[1]:.2f}  "
                f"pitch={np.rad2deg(self.torso_ang):.1f}°  "
                f"contacts={self.contact.tolist()}  "
                f"FL=({feet[0,0]:.2f},{feet[0,1]:.2f})  FR=({feet[1,0]:.2f},{feet[1,1]:.2f})  "
                f"BL=({feet[2,0]:.2f},{feet[2,1]:.2f})  BR=({feet[3,0]:.2f},{feet[3,1]:.2f})"
            )


# Register the environment with Gym so that gym.make("ImprovedAnt-v0") works.
# The entry point uses this module's actual import name rather than a
# hard-coded "improved_ant". Registration then still resolves when the file is
# imported from inside a package, or when it runs as a script ("__main__"),
# where the old string would force a second import and a duplicate
# registration.
register(
    id="ImprovedAnt-v0",
    entry_point=f"{__name__}:ImprovedAntEnv",
    max_episode_steps=1000,
)


# Optional smoke test: run up to 300 random-action steps and report the return.
if __name__ == "__main__":
    demo_env = ImprovedAntEnv()
    demo_env.reset()
    total_return = 0.0
    n_steps = 0
    for n_steps in range(1, 301):
        _, step_reward, finished, _ = demo_env.step(demo_env.action_space.sample())
        total_return += step_reward
        if finished:
            break
    print(f"Quick rollout steps={n_steps}, return={total_return:.2f}")
